This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Embedding gradient performance optimization on GPU #16355

Merged
merged 7 commits into apache:master
Oct 5, 2019

Conversation

MoisesHer
Contributor

Description

This PR adds a GPU-specific Embedding-backward operator.
Two new CUDA kernels have been implemented to improve the performance of the operator on GPU.
According to our measurements on Volta GPUs, the previous version took 2.2 ms,
whereas the new implementation takes 0.3 ms, i.e. more than a 7x speedup.

Checklist

Essentials

  • [x] Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests already existed (tests/python/gpu/test_operator_gpu:test_embedding_with_type) and changes do not affect correctness.
  • Code is well-documented:
  • For new C++ functions in header files, their functionalities and arguments are documented.

Changes

  • [x] Embedding-backward operator for GPU, test: tests/python/gpu/test_operator_gpu:test_embedding_with_type

@ptrendx ptrendx self-requested a review October 2, 2019 17:38
@MoisesHer MoisesHer changed the title from "Pr embedding gradient" to "Embedding gradient operator on GPU" Oct 2, 2019
@MoisesHer MoisesHer changed the title from "Embedding gradient operator on GPU" to "Embedding gradient performance optimization on GPU" Oct 2, 2019
@ptrendx
Member

ptrendx commented Oct 2, 2019

@sxjscience FYI

Member

@ptrendx ptrendx left a comment


LGTM

@sxjscience
Member

sxjscience commented Oct 5, 2019

Nice! LGTM. So the BinarySearch version of FindBounds has complexity O(|V| log |N|), where |V| is the vocabulary size and |N| is the number of indices. I guess our initial version (https://github.com/dmlc/mshadow/blob/bc49327a44650c3f2b427e953ff95d2c27566c04/mshadow/cuda/tensor_gpu-inl.cuh#L619-L672) has complexity O(|N|) for finding the boundaries. Thus, in some workloads (those in which |N| is small), the O(|N|) version might be faster.

@ptrendx
Member

ptrendx commented Oct 5, 2019

Moises did a performance comparison between the new version and both the old working version and the old buggy one. The new kernel is faster than the old working version in all cases and about the same speed as the buggy one. The biggest performance change is actually seen when varying how many distinct elements are in the input data, as a small number of distinct elements limits parallelism in the backward pass.

@ptrendx ptrendx merged commit 8096421 into apache:master Oct 5, 2019
apeforest pushed a commit that referenced this pull request Nov 6, 2019
* Add Embedding backward Op for GPU

* Add some code documentation

* Use unnamed namespace for integer log2 function

* Fix lint issues

* Fix one more lint problem

* Remove unnecessary conditions ops

* Fix one more lint problem
3 participants